{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "z9ln8ylCKHsx"
   },
   "source": [
    "# L3-E - Linear Quantization II: Quantizing Weights & Activations for Inference\n",
    "\n",
    "In this lesson, you will continue exploring linear quantization, this time by running inference with quantized weights."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the next cell to import the functions from the previous lesson(s) of `Linear Quantization II` that you will need to follow along with the video.\n",
    "\n",
    "- To access the `helper.py` file, you can click `File --> Open...`, on the top left."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from helper import linear_q_symmetric, get_q_scale_symmetric"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qUw1gQUu5yIe"
   },
   "source": [
    "## Linear Quantization: Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- `W8A32` means 8-bit weights and 32-bit activations: the weights are stored as `int8` and the activations stay in `float32`.\n",
    "- For simplicity, the linear layer has no bias."
   ]
  },
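  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a compact summary of the `W8A32` setup above, and assuming symmetric quantization (a zero point of 0, as in the example that follows), the layer dequantizes the stored weights and then performs an ordinary `float32` linear operation. The symbols $\\hat{W}$, $x$, and $y$ are notation introduced here for illustration:\n",
    "\n",
    "$$\\hat{W} = s_w \\, q_w, \\qquad y = x \\, \\hat{W}^{\\top}$$\n",
    "\n",
    "Here $q_w$ holds the `int8` weights, $s_w$ is the scale, and $x$ is the `float32` input."
   ]
  },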
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 165,
    "id": "vGll7vBT6BGI"
   },
   "outputs": [],
   "source": [
    "def quantized_linear_W8A32_without_bias(input, q_w, s_w, z_w):\n",
    "    assert input.dtype == torch.float32\n",
    "    assert q_w.dtype == torch.int8\n",
    "\n",
    "    dequantized_weight = q_w.to(torch.float32) * s_w + z_w\n",
    "    output = torch.nn.functional.linear(input, dequantized_weight)\n",
    "    \n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 113,
     "status": "ok",
     "timestamp": 1705361606028,
     "user": {
      "displayName": "Marc Sun",
      "userId": "00829270524676809963"
     },
     "user_tz": 300
    },
    "height": 29,
    "id": "7sPRcXM-AHTR"
   },
   "outputs": [],
   "source": [
    "input = torch.tensor([1, 2, 3], dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 133,
     "status": "ok",
     "timestamp": 1705361607599,
     "user": {
      "displayName": "Marc Sun",
      "userId": "00829270524676809963"
     },
     "user_tz": 300
    },
    "height": 63,
    "id": "o9IQsM1295iz",
    "outputId": "38f08506-db80-40ec-9840-e184a8a6fe5a"
   },
   "outputs": [],
   "source": [
    "weight = torch.tensor([[-2,   -1.13, 0.42],\n",
    "                       [-1.51, 0.25, 1.62],\n",
    "                       [0.23,  1.35, 2.15]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "q_w, s_w  = linear_q_symmetric(weight)"
   ]
  },
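  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The body of `linear_q_symmetric` lives in `helper.py` and is not shown in this notebook. As a rough sketch of what symmetric per-tensor quantization typically does, the cell below assumes the scale is the largest absolute weight divided by the largest `int8` value, with a zero point of 0. The `sketch_*` functions are hypothetical stand-ins written for illustration, not the helper's actual code, and their results may differ from the helper's."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical stand-ins (assumption): NOT the actual helper.py implementation.\n",
    "def sketch_q_scale_symmetric(tensor, dtype=torch.int8):\n",
    "    # Map the largest absolute value onto the largest representable integer.\n",
    "    r_max = tensor.abs().max().item()\n",
    "    q_max = torch.iinfo(dtype).max\n",
    "    return r_max / q_max\n",
    "\n",
    "def sketch_linear_q_symmetric(tensor, dtype=torch.int8):\n",
    "    scale = sketch_q_scale_symmetric(tensor, dtype)\n",
    "    info = torch.iinfo(dtype)\n",
    "    # With a zero point of 0, quantizing is: divide by scale, round, clamp, cast.\n",
    "    q = torch.round(tensor / scale).clamp(info.min, info.max).to(dtype)\n",
    "    return q, scale\n",
    "\n",
    "# Same values as the `weight` tensor defined above.\n",
    "sketch_q, sketch_s = sketch_linear_q_symmetric(\n",
    "    torch.tensor([[-2.0, -1.13, 0.42],\n",
    "                  [-1.51, 0.25, 1.62],\n",
    "                  [0.23, 1.35, 2.15]])\n",
    ")\n",
    "print(sketch_q)\n",
    "print(sketch_s)"
   ]
  },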
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "q_w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "s_w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "output = quantized_linear_W8A32_without_bias(input,\n",
    "                                             q_w,\n",
    "                                             s_w,\n",
    "                                             0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "print(f\"This is the W8A32 output: {output}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "fp32_output = torch.nn.functional.linear(input, weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "print(f\"This is the output if we don't quantize: {fp32_output}\")"
   ]
  },
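  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see how much accuracy the 8-bit weights cost, one option is to compare the two outputs directly. The cell below is a minimal sketch that assumes `output` and `fp32_output` were computed by the cells above; the names `abs_error` and `mse` are introduced here for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assumes `output` and `fp32_output` exist from the cells above.\n",
    "abs_error = (output - fp32_output).abs()\n",
    "mse = abs_error.square().mean()\n",
    "\n",
    "print(f\"Absolute error per output element: {abs_error}\")\n",
    "print(f\"Mean squared error: {mse}\")"
   ]
  },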
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "NgmjMISuIyzF",
    "3dMVgqcNJCqE",
    "MFx2m7RmzRd5",
    "l6XbkmOzYMrC",
    "8NS1TnQt6E6v"
   ],
   "provenance": [
    {
     "file_id": "12_pQW6LB80u98m72_YKwM4ph7eb5Uf3g",
     "timestamp": 1705360205453
    },
    {
     "file_id": "1U9pm4j_uAD8EO7OrEPvdpwFjQKO3suuH",
     "timestamp": 1700237940920
    }
   ]
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
